1D Convolution을 기본 구성 요소로 하는 EEG classifier를 학습해보는 노트북.
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Load some packages
import os
import glob
import json
import datetime
from copy import deepcopy
import matplotlib.pyplot as plt
import pprint
from IPython.display import clear_output
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
from tqdm.auto import tqdm
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from typing import Type, Any, Callable, Union, List, Optional
from itertools import cycle
# custom package
from utils.eeg_dataset import *
# notebook name
def get_notebook_name():
    """Return the name of the notebook that is currently running.

    The lookup is delegated to ``ipynbname.name()``. The package is imported
    inside the function, so it is only needed when this helper is called.
    """
    from ipynbname import name as current_notebook_name

    return current_notebook_name()
nb_fname = get_notebook_name()
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text
# Start from matplotlib's stock style sheet. The commented list below records
# the other style names that can be passed to `plt.style.use`.
plt.style.use('default')
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast',
# 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind',
# 'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted',
# 'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk',
# 'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']
# Draw images without smoothing between pixels.
plt.rcParams['image.interpolation'] = 'nearest'
# Font that contains Korean glyphs, so Hangul labels render in figures.
plt.rcParams["font.family"] = 'NanumGothic' # for Hangul in Windows
# Report the installed PyTorch version and pick the compute device:
# CUDA when it is available, otherwise the CPU.
print('PyTorch version:', torch.__version__)
cuda_ready = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ready else 'cpu')
print('cuda is available.' if cuda_ready else 'cuda is unavailable.')
PyTorch version: 1.9.0 cuda is available.
# Data file path
data_path = r'dataset/02_Curated_Data/'
meta_path = os.path.join(data_path, 'metadata_debug.json')

# Decode the JSON explicitly as UTF-8. Without `encoding`, `open` falls back
# to the locale's preferred encoding, which differs between machines
# (notably on Windows) and can corrupt or reject non-ASCII text.
with open(meta_path, 'r', encoding='utf-8') as json_file:
    metadata = json.load(json_file)

# Show the first record as a sample of the metadata schema.
pprint.pprint(metadata[0])
{'age': 78,
'birth': '1940-06-02',
'dx1': 'mci_rf',
'edfname': '00001809_261018',
'events': [[0, 'Start Recording'],
[0, 'New Montage - Montage 002'],
[36396, 'Eyes Open'],
[72518, 'Eyes Closed'],
[73862, 'Eyes Open'],
[75248, 'Eyes Closed'],
[76728, 'swallowing'],
[77978, 'Eyes Open'],
[79406, 'Eyes Closed'],
[79996, 'Photic On - 3.0 Hz'],
[80288, 'Eyes Open'],
[81296, 'Eyes Closed'],
[82054, 'Photic Off'],
[84070, 'Photic On - 6.0 Hz'],
[84488, 'Eyes Open'],
[85538, 'Eyes Closed'],
[86086, 'Photic Off'],
[88144, 'Photic On - 9.0 Hz'],
[90160, 'Photic Off'],
[91458, 'Eyes Open'],
[92218, 'Photic On - 12.0 Hz'],
[92762, 'Eyes Closed'],
[94198, 'Photic Off'],
[94742, 'Eyes Open'],
[95708, 'Eyes Closed'],
[96256, 'Photic On - 15.0 Hz'],
[98272, 'Photic Off'],
[100330, 'Photic On - 18.0 Hz'],
[102346, 'Photic Off'],
[102596, 'Eyes Open'],
[103856, 'Eyes Closed'],
[104361, 'Photic On - 21.0 Hz'],
[106420, 'Photic Off'],
[106880, 'Eyes Open'],
[107804, 'Eyes Closed'],
[108435, 'Photic On - 24.0 Hz'],
[110452, 'Photic Off'],
[111080, 'Eyes Open'],
[112004, 'Eyes Closed'],
[112509, 'Photic On - 27.0 Hz'],
[114528, 'Photic Off'],
[114864, 'Eyes Open'],
[116124, 'Eyes Closed'],
[116544, 'Photic On - 30.0 Hz'],
[118602, 'Photic Off'],
[126672, 'artifact'],
[134030, 'Move'],
[135584, 'Eyes Open'],
[136668, 'Eyes Closed'],
[139818, 'Eyes Open'],
[141414, 'Eyes Closed'],
[145000, 'Paused']],
'label': ['mci', 'mci_amnestic', 'mci_amnestic_rf'],
'record': '2018-10-26T15:46:26',
'serial': '00001'}
# Ordered class definitions. A record belongs to the first entry whose
# 'include' tags are all present in its label list and whose 'exclude' tags
# are all absent. The position of an entry in this list is its integer label.
diagnosis_filter = [
    # Normal
    {'type': 'Normal', 'include': ['normal'], 'exclude': []},
    # Non-vascular MCI
    {'type': 'Non-vascular MCI', 'include': ['mci'], 'exclude': ['mci_vascular']},
    # Non-vascular dementia
    {'type': 'Non-vascular dementia', 'include': ['dementia'], 'exclude': ['vd']},
]


def generate_class_label(label):
    """Map a record's diagnosis tags to a ``(class_index, class_name)`` pair.

    Args:
        label: iterable of diagnosis tag strings attached to one record.

    Returns:
        ``(index, type)`` of the first entry of ``diagnosis_filter`` that the
        tags satisfy, or ``(-1, 'The others')`` when no entry matches.
    """
    tags = set(label)
    for class_index, rule in enumerate(diagnosis_filter):
        has_all_required = tags.issuperset(rule['include'])
        has_no_forbidden = tags.isdisjoint(rule['exclude'])
        if has_all_required and has_no_forbidden:
            return (class_index, rule['type'])
    return (-1, 'The others')


# Human-readable class names, indexed by integer class label.
class_label_to_type = [entry['type'] for entry in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)
class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']
# Bucket the records by class label. Records that match none of the classes
# (label -1) are left out of every bucket.
splitted_metadata = [[] for _ in diagnosis_filter]
for record in metadata:
    class_index, class_name = generate_class_label(record['label'])
    if class_index < 0:
        continue
    # The record dict is annotated in place with its class information.
    record['class_type'] = class_name
    record['class_label'] = class_index
    splitted_metadata[class_index].append(record)

# Report the size of every bucket, flagging classes that received no data.
for i, split in enumerate(splitted_metadata):
    if split:
        print(f'- There are {len(split):} data belonging to {split[0]["class_type"]}')
    else:
        print(f'(Warning) Split group {i} has no data.')
- There are 463 data belonging to Normal - There are 347 data belonging to Non-vascular MCI - There are 229 data belonging to Non-vascular dementia
# Fixed seed so the train/validation/test partition is reproducible.
random.seed(0)

# Train : Val : Test = 8 : 1 : 1
ratio1, ratio2 = 0.8, 0.1

metadata_train, metadata_val, metadata_test = [], [], []

# Stratified split: every class bucket is shuffled and cut on its own, so the
# three subsets keep roughly the same class proportions.
for split in splitted_metadata:
    random.shuffle(split)
    n1 = round(len(split) * ratio1)
    n2 = n1 + round(len(split) * ratio2)
    metadata_train += split[:n1]
    metadata_val += split[n1:n2]
    metadata_test += split[n2:]

# Mix the classes inside each subset.
for subset in (metadata_train, metadata_val, metadata_test):
    random.shuffle(subset)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))

print('\n', '--- Recheck ---', '\n')

# Count the records per class label in every subset.
num_classes = len(class_label_to_type)
train_class_nums = np.zeros(num_classes, dtype=np.int32)
val_class_nums = np.zeros(num_classes, dtype=np.int32)
test_class_nums = np.zeros(num_classes, dtype=np.int32)
for counts, subset in ((train_class_nums, metadata_train),
                       (val_class_nums, metadata_val),
                       (test_class_nums, metadata_test)):
    for m in subset:
        counts[m['class_label']] += 1

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

# Re-seed without an argument so later random draws are no longer tied to
# the fixed seed used for the split.
random.seed()

# print([m['serial'] for m in metadata_train[:15]])
# print([m['serial'] for m in metadata_val[:15]])
# print([m['serial'] for m in metadata_test[:15]])
Train data size : 831 Validation data size : 104 Test data size : 104 --- Recheck --- Train data label distribution : [370 278 183] 831 Val data label distribution : [46 35 23] 104 Test data label distribution : [47 34 23] 104
# Mean and standard deviation of age, computed from the training subset only.
ages = np.array([m['age'] for m in metadata_train])
age_mean = ages.mean()
age_std = ages.std()

print('Age mean and standard deviation:')
print(age_mean, age_std)
Age mean and standard deviation: 69.92779783393502 9.817569889945597
crop_length = 200*60 # 1 minute (presumably at a 200 Hz sampling rate - confirm)

# Preprocessing pipeline shared by all three datasets; the steps run in the
# order listed.
preprocessing_steps = [
    EEGNormalizeAge(mean=age_mean, std=age_std),
    EEGDropPhoticChannel(),
    EEGRandomCrop(crop_length=crop_length),
    EEGNormalizePerSignal(),
    EEGToTensor(),
]
composed = transforms.Compose(preprocessing_steps)

train_dataset = EEGDataset(data_path, metadata_train, composed)
val_dataset = EEGDataset(data_path, metadata_val, composed)
test_dataset = EEGDataset(data_path, metadata_test, composed)

# Preview the first sample of each dataset (signal shape, then the full
# sample), with a ruler line between datasets.
for position, dataset in enumerate((train_dataset, val_dataset, test_dataset)):
    if position > 0:
        print()
        print('-' * 100)
        print()
    print(dataset[0]['signal'].shape)
    print(dataset[0])
torch.Size([20, 12000])
{'signal': tensor([[-0.6351, -0.5564, -0.5564, ..., -0.3833, -0.2418, -0.2418],
[ 0.0460, 0.1812, 0.4065, ..., 0.4516, 0.4516, 0.3615],
[-0.4076, -0.2650, -0.1224, ..., -0.3600, -0.2650, -0.1224],
...,
[ 1.0567, 1.0567, 0.9386, ..., -1.6597, -1.6597, -1.5416],
[ 0.4995, 0.3253, 0.0640, ..., -0.4586, -0.3715, -0.2844],
[ 0.4015, 0.3808, 0.3946, ..., 0.3462, 0.3738, 0.3531]]), 'age': tensor(-1.2149), 'class_label': tensor(0), 'metadata': {'serial': '01012', 'edfname': '01212635_270515', 'birth': '1956-06-01', 'record': '2015-05-27T09:37:24', 'age': 58, 'dx1': 'cb_normal', 'label': ['normal', 'cb_normal'], 'events': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [400, 'Eyes Open'], [7918, 'Eyes Closed'], [14091, 'Eyes Open'], [18208, 'Eyes Closed'], [24256, 'Eyes Open'], [30724, 'Eyes Closed'], [36562, 'Eyes Open'], [42190, 'Eyes Closed'], [48910, 'Eyes Open'], [55126, 'Eyes Closed'], [60417, 'Eyes Open'], [66004, 'Eyes Closed'], [71968, 'Eyes Open'], [78310, 'Eyes Closed'], [84442, 'Eyes Open'], [90070, 'Eyes Closed'], [96076, 'Eyes Open'], [102082, 'Eyes Closed'], [108844, 'Eyes Open'], [113674, 'Eyes Closed'], [120000, 'Paused']], 'class_type': 'Normal', 'class_label': 0}}
----------------------------------------------------------------------------------------------------
torch.Size([20, 12000])
{'signal': tensor([[-8.1465e-02, -1.0128e-01, -1.8056e-01, ..., -1.6274e+00,
-1.7463e+00, -1.7662e+00],
[-1.0585e-02, -1.0585e-02, -1.0585e-02, ..., -8.0078e-01,
-1.0971e+00, -1.3934e+00],
[ 9.9020e-01, 6.6443e-01, 5.0155e-01, ..., 5.0155e-01,
-3.1287e-01, -1.2902e+00],
...,
[ 1.7799e-03, -2.0164e-01, -2.0164e-01, ..., 6.1204e-01,
1.7799e-03, -2.0164e-01],
[ 9.4771e-01, 7.5592e-01, 5.6413e-01, ..., 2.0985e+00,
2.2903e+00, 2.0985e+00],
[-1.0994e-01, -1.6587e-01, -2.1381e-01, ..., -1.0687e+00,
7.6093e-01, 1.6478e+00]]), 'age': tensor(0.7204), 'class_label': tensor(1), 'metadata': {'serial': '00700', 'edfname': '00985401_011117', 'birth': '1940-09-09', 'record': '2017-11-01T14:20:48', 'age': 77, 'dx1': 'mci amnestic', 'label': ['mci', 'mci_amnestic'], 'events': [[0, 'Start Recording'], [0, 'New Montage - Montage 002'], [1705, 'Eyes Open'], [5402, 'Eyes Closed'], [13046, 'Eyes Open'], [17666, 'Eyes Closed'], [30308, 'Eyes Closed'], [36272, 'Eyes Open'], [41774, 'Eyes Closed'], [48958, 'Eyes Open'], [55510, 'Eyes Closed'], [61641, 'Eyes Open'], [66766, 'Eyes Closed'], [72730, 'Eyes Open'], [78988, 'Eyes Closed'], [87770, 'Eyes Open'], [90542, 'Eyes Closed'], [97096, 'Eyes Open'], [102178, 'Eyes Closed'], [110872, 'Eyes Open'], [113728, 'Eyes Closed'], [122052, 'Photic On - 3.0 Hz'], [122428, 'Eyes Open'], [123309, 'Eyes Closed'], [124068, 'Photic Off'], [126126, 'Photic On - 6.0 Hz'], [128142, 'Photic Off'], [130158, 'Photic On - 9.0 Hz'], [132216, 'Photic Off'], [132718, 'Eyes Open'], [133600, 'Eyes Closed'], [134232, 'Photic On - 12.0 Hz'], [136248, 'Photic Off'], [138306, 'Photic On - 15.0 Hz'], [140322, 'Photic Off'], [142380, 'Photic On - 18.0 Hz'], [142630, 'Eyes Open'], [143302, 'Eyes Closed'], [144396, 'Photic Off'], [146412, 'Photic On - 21.0 Hz'], [148428, 'Photic Off'], [150486, 'Photic On - 24.0 Hz'], [152502, 'Photic Off'], [152710, 'Eyes Open'], [153550, 'Eyes Closed'], [154560, 'Photic On - 27.0 Hz'], [156576, 'Photic Off'], [158592, 'Photic On - 30.0 Hz'], [160608, 'Photic Off'], [160942, 'Eyes Open'], [161698, 'Eyes Closed'], [169132, 'Eyes Open'], [170098, 'Eyes Closed'], [173600, 'Paused']], 'class_type': 'Non-vascular MCI', 'class_label': 1}}
----------------------------------------------------------------------------------------------------
torch.Size([20, 12000])
{'signal': tensor([[ 0.0611, 0.0397, -0.0033, ..., 0.4045, 0.3616, 0.3401],
[ 0.5079, 0.5079, 0.5079, ..., -1.6499, -1.8554, -1.9582],
[-0.6028, -0.6028, -0.4064, ..., 0.3790, 0.3790, 0.3790],
...,
[ 1.5874, 1.8140, 1.3607, ..., -0.2258, -0.4524, -0.9057],
[ 0.5565, 0.5565, 0.5565, ..., 1.6844, 1.8724, 1.6844],
[ 0.1989, 0.0538, -0.0244, ..., 0.3999, 0.1840, 0.1506]]), 'age': tensor(1.0259), 'class_label': tensor(0), 'metadata': {'serial': '00299', 'edfname': '00671212_160819', 'birth': '1938-08-17', 'record': '2019-08-16T10:57:03', 'age': 80, 'dx1': 'smi', 'label': ['normal', 'smi'], 'events': [[0, 'Start Recording'], [0, 'New Montage - Montage 005'], [1773, 'Eyes Closed'], [6000, 'Cz check'], [7612, 'Eyes Open'], [12912, 'Eyes Closed'], [18078, 'Eyes Open'], [23958, 'Eyes Closed'], [29288, 'Eyes Open'], [35934, 'Eyes Closed'], [41856, 'Eyes Open'], [47862, 'Eyes Closed'], [54460, 'Eyes Open'], [59962, 'Eyes Closed'], [66178, 'Eyes Open'], [71008, 'Eyes Closed'], [73948, 'Photic On - 3.0 Hz'], [74158, 'Eyes Open'], [75166, 'Eyes Closed'], [75964, 'Photic Off'], [77980, 'Photic On - 6.0 Hz'], [78358, 'Eyes Open'], [79282, 'Eyes Closed'], [80038, 'Photic Off'], [82054, 'Photic On - 9.0 Hz'], [84070, 'Photic Off'], [86128, 'Photic On - 12.0 Hz'], [87640, 'Eyes Open'], [88144, 'Photic Off'], [88396, 'Eyes Closed'], [90202, 'Photic On - 15.0 Hz'], [92218, 'Photic Off'], [92722, 'Eyes Open'], [93772, 'Eyes Closed'], [94234, 'Photic On - 18.0 Hz'], [96250, 'Photic Off'], [98308, 'Photic On - 21.0 Hz'], [100324, 'Photic Off'], [102382, 'Photic On - 24.0 Hz'], [104398, 'Photic Off'], [106414, 'Photic On - 27.0 Hz'], [108430, 'Photic Off'], [110488, 'Photic On - 30.0 Hz'], [111580, 'Eyes Open'], [111790, 'Photic Off'], [112420, 'Eyes Closed'], [113200, 'Paused']], 'class_type': 'Normal', 'class_label': 0}}
print('Current PyTorch device:', device)

# Worker processes stay disabled on every device.
num_workers = 0 # A number other than 0 causes an error
# Pinned host memory is only requested when batches will be copied to a GPU.
pin_memory = device.type == 'cuda'

train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          shuffle=True,
                          drop_last=True,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          collate_fn=eeg_collate_fn)

# Smoke test: pull the first five batches, move their tensors to the target
# device and print the shapes.
for i_batch, sample_batched in enumerate(train_loader):
    # `Tensor.to` returns a new tensor instead of moving in place, so the
    # result has to be stored for the transfer to take effect.
    for key in ('signal', 'age', 'class_label'):
        sample_batched[key] = sample_batched[key].to(device)

    print(i_batch,
          sample_batched['signal'].shape,
          sample_batched['age'].shape,
          sample_batched['class_label'].shape,
          len(sample_batched['metadata']))

    if i_batch > 3:
        break
Current PyTorch device: cuda 0 torch.Size([32, 20, 12000]) torch.Size([32]) torch.Size([32]) 32 1 torch.Size([32, 20, 12000]) torch.Size([32]) torch.Size([32]) 32 2 torch.Size([32, 20, 12000]) torch.Size([32]) torch.Size([32]) 32 3 torch.Size([32, 20, 12000]) torch.Size([32]) torch.Size([32]) 32 4 torch.Size([32, 20, 12000]) torch.Size([32]) torch.Size([32]) 32
# Options that are identical for the three loaders.
loader_options = dict(batch_size=32,
                      num_workers=num_workers,
                      pin_memory=pin_memory,
                      collate_fn=eeg_collate_fn)

# Training batches are reshuffled every epoch and the incomplete final batch
# is dropped; validation and test loaders keep the data order and every sample.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **loader_options)
def count_parameters(model):
    """Return the total number of trainable scalars in ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def visualize_network_tensorboard(model, name):
    """Log the graph of ``model`` to TensorBoard using one training batch.

    The first batch of the module-level ``train_loader`` is moved to the
    module-level ``device``, the model graph is recorded with it, and the
    model is then run once with ``print_shape=True`` so it prints the feature
    shape ahead of its final pooling. Nothing is traced if the loader yields
    no batch.

    Args:
        model: network whose ``forward`` accepts ``(x, age, print_shape=...)``.
        name: suffix appended to the notebook name to form the run directory
            ``runs/<notebook>_<name>``.
    """
    from torch.utils.tensorboard import SummaryWriter

    # default `log_dir` is "runs" - we'll be more specific here
    writer = SummaryWriter('runs/' + nb_fname + '_' + name)
    try:
        # Only the first batch is needed as an example input.
        sample_batched = next(iter(train_loader), None)
        if sample_batched is not None:
            x = sample_batched['signal'].to(device)
            age = sample_batched['age'].to(device)

            # apply model on whole batch directly on device
            writer.add_graph(model, (x, age))
            model(x, age, print_shape=True)
    finally:
        # Close the writer even when tracing or the forward pass fails.
        writer.close()
class TinyCNN(nn.Module):
    """Small 1D CNN classifier: two conv blocks, global pooling, two FC layers.

    Args:
        n_input: number of input signal channels.
        n_output: number of output classes.
        stride: stride of the first convolution.
        n_channel: number of feature channels in both convolutions and in the
            hidden fully connected layer.
        use_age: when True, the ``age`` value is appended to the pooled
            features before the fully connected layers.
        final_pool: ``'average'`` or ``'max'``; global pooling applied over
            the time axis.

    Raises:
        ValueError: if ``final_pool`` is neither ``'average'`` nor ``'max'``.
    """

    def __init__(self, n_input=20, n_output=3, stride=7, n_channel=64,
                 use_age=True, final_pool='average'):
        super().__init__()

        if final_pool not in {'average', 'max'}:
            raise ValueError("final_pool must be set to one of ['average', 'max']")

        self.use_age = use_age

        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=35, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)

        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=7)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(2)

        if final_pool == 'average':
            self.final_pool = nn.AdaptiveAvgPool1d(1)
        elif final_pool == 'max':
            self.final_pool = nn.AdaptiveMaxPool1d(1)

        # One extra input feature when age is concatenated.
        if self.use_age:
            self.fc1 = nn.Linear(n_channel + 1, n_channel)
        else:
            self.fc1 = nn.Linear(n_channel, n_channel)

        self.dropout = nn.Dropout(p=0.3)
        self.bnfc1 = nn.BatchNorm1d(n_channel)
        self.fc2 = nn.Linear(n_channel, n_output)

    def reset_weights(self):
        """Re-initialise every submodule that provides ``reset_parameters``."""
        for m in self.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

    def forward(self, x, age, print_shape=False):
        """Compute class scores for a batch.

        Args:
            x: signal batch of shape ``(batch, n_input, length)``.
            age: one age value per sample; reshaped to ``(batch, 1)`` when
                ``use_age`` is set, otherwise ignored.
            print_shape: when True, print the feature shape ahead of the
                final global pooling.

        Returns:
            Tensor of shape ``(batch, n_output)`` with raw class scores
            (no softmax applied).
        """
        # conv-bn-relu-pool
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))

        if print_shape:
            print('Shape right before squeezing:', x.shape)

        # Drop only the pooled time axis. A bare `squeeze()` would also remove
        # the batch axis when the batch holds a single sample.
        x = self.final_pool(x).flatten(start_dim=1)

        if self.use_age:
            x = torch.cat((x, age.reshape(-1, 1)), dim=1)

        # fc-bn-dropout-relu-fc
        x = self.fc1(x)
        x = self.bnfc1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.fc2(x)

        return x
def generate_TinyCNN():
    """Build a 3-class ``TinyCNN`` with age input and global max pooling.

    The number of input channels is read from the first training sample.
    """
    n_signal_channels = train_dataset[0]['signal'].shape[0]
    return TinyCNN(n_input=n_signal_channels, n_output=3,
                   use_age=True, final_pool='max')


# Build a throwaway instance to inspect the architecture.
model = generate_TinyCNN().to(device, dtype=torch.float32)
print(model)
print()

# tensorboard visualization
visualize_network_tensorboard(model, 'TinyCNN')

# number of parameters
n = count_parameters(model)
print(f'The Number of parameters of the model: {n:,}')

# The instance was only needed for the summary above.
del model
TinyCNN( (conv1): Conv1d(20, 64, kernel_size=(35,), stride=(7,)) (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False) (conv2): Conv1d(64, 64, kernel_size=(7,), stride=(1,)) (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (final_pool): AdaptiveMaxPool1d(output_size=1) (fc1): Linear(in_features=65, out_features=64, bias=True) (dropout): Dropout(p=0.3, inplace=False) (bnfc1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (fc2): Linear(in_features=64, out_features=3, bias=True) )
C:\Users\IPIS-Minjae\anaconda3\envs\EEG_Project\lib\site-packages\torch\nn\functional.py:652: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ..\c10/core/TensorImpl.h:1156.) return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
Shape right before squeezing: torch.Size([32, 64, 210]) The Number of parameters of the model: 78,403
class M5(nn.Module):
    """Five-block 1D CNN classifier with global pooling and two FC layers.

    The first two convolutions produce ``n_channel`` feature channels and the
    last three produce ``2 * n_channel``.

    Args:
        n_input: number of input signal channels.
        n_output: number of output classes.
        stride: stride of the first convolution.
        n_channel: base number of feature channels.
        use_age: when True, the ``age`` value is appended to the pooled
            features before the fully connected layers.
        final_pool: ``'average'`` or ``'max'``; global pooling applied over
            the time axis.

    Raises:
        ValueError: if ``final_pool`` is neither ``'average'`` nor ``'max'``.
    """

    def __init__(self, n_input=20, n_output=3, stride=4, n_channel=256,
                 use_age=True, final_pool='average'):
        super().__init__()

        if final_pool not in {'average', 'max'}:
            raise ValueError("final_pool must be set to one of ['average', 'max']")

        self.use_age = use_age

        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=41, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)

        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=11)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=11)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(2)

        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=11)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(2)

        self.conv5 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=11)
        self.bn5 = nn.BatchNorm1d(2 * n_channel)
        self.pool5 = nn.MaxPool1d(2)

        if final_pool == 'average':
            self.final_pool = nn.AdaptiveAvgPool1d(1)
        elif final_pool == 'max':
            self.final_pool = nn.AdaptiveMaxPool1d(1)

        # One extra input feature when age is concatenated.
        if self.use_age:
            self.fc1 = nn.Linear(2 * n_channel + 1, 2 * n_channel)
        else:
            self.fc1 = nn.Linear(2 * n_channel, 2 * n_channel)

        self.dropout = nn.Dropout(p=0.3)
        self.bnfc1 = nn.BatchNorm1d(2 * n_channel)
        self.fc2 = nn.Linear(2 * n_channel, n_output)

    def reset_weights(self):
        """Re-initialise every submodule that provides ``reset_parameters``."""
        for m in self.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

    def forward(self, x, age, print_shape=False):
        """Compute class scores for a batch.

        Args:
            x: signal batch of shape ``(batch, n_input, length)``.
            age: one age value per sample; reshaped to ``(batch, 1)`` when
                ``use_age`` is set, otherwise ignored.
            print_shape: when True, print the feature shape ahead of the
                final global pooling.

        Returns:
            Tensor of shape ``(batch, n_output)`` with raw class scores
            (no softmax applied).
        """
        # conv-bn-relu-pool
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.pool4(F.relu(self.bn4(self.conv4(x))))
        x = self.pool5(F.relu(self.bn5(self.conv5(x))))

        if print_shape:
            print('Shape right before squeezing:', x.shape)

        # Drop only the pooled time axis. A bare `squeeze()` would also remove
        # the batch axis when the batch holds a single sample.
        x = self.final_pool(x).flatten(start_dim=1)

        if self.use_age:
            x = torch.cat((x, age.reshape(-1, 1)), dim=1)

        # fc-bn-dropout-relu-fc
        x = self.fc1(x)
        x = self.bnfc1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.fc2(x)

        return x
def generate_M5():
    """Build a 3-class ``M5`` with age input and global max pooling.

    The number of input channels is read from the first training sample.
    """
    n_signal_channels = train_dataset[0]['signal'].shape[0]
    return M5(n_input=n_signal_channels, n_output=3,
              use_age=True, final_pool='max')


# Build a throwaway instance to inspect the architecture.
model = generate_M5().to(device, dtype=torch.float32)
print(model)
print()

# tensorboard visualization
visualize_network_tensorboard(model, 'M5')

# number of parameters
n = count_parameters(model)
print(f'The Number of parameters of the model: {n:,}')

# The instance was only needed for the summary above.
del model
M5( (conv1): Conv1d(20, 256, kernel_size=(41,), stride=(4,)) (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False) (conv2): Conv1d(256, 256, kernel_size=(11,), stride=(1,)) (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv3): Conv1d(256, 512, kernel_size=(11,), stride=(1,)) (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv4): Conv1d(512, 512, kernel_size=(11,), stride=(1,)) (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv5): Conv1d(512, 512, kernel_size=(11,), stride=(1,)) (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool5): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (final_pool): AdaptiveMaxPool1d(output_size=1) (fc1): Linear(in_features=513, out_features=512, bias=True) (dropout): Dropout(p=0.3, inplace=False) (bnfc1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (fc2): Linear(in_features=512, out_features=3, bias=True) ) Shape right before squeezing: torch.Size([32, 512, 37]) The Number of parameters of the model: 8,411,651
def generate_M5_no_age():
    """Build a 3-class ``M5`` that ignores age and uses global max pooling.

    The number of input channels is read from the first training sample.
    """
    n_signal_channels = train_dataset[0]['signal'].shape[0]
    return M5(n_input=n_signal_channels, n_output=3,
              use_age=False, final_pool='max')


# Build a throwaway instance to inspect the architecture.
model = generate_M5_no_age().to(device, dtype=torch.float32)
print(model)
print()

# tensorboard visualization
visualize_network_tensorboard(model, 'M5-no-age')

# number of parameters
n = count_parameters(model)
print(f'The Number of parameters of the model: {n:,}')

# The instance was only needed for the summary above.
del model
M5( (conv1): Conv1d(20, 256, kernel_size=(41,), stride=(4,)) (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False) (conv2): Conv1d(256, 256, kernel_size=(11,), stride=(1,)) (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv3): Conv1d(256, 512, kernel_size=(11,), stride=(1,)) (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv4): Conv1d(512, 512, kernel_size=(11,), stride=(1,)) (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv5): Conv1d(512, 512, kernel_size=(11,), stride=(1,)) (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pool5): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (final_pool): AdaptiveMaxPool1d(output_size=1) (fc1): Linear(in_features=512, out_features=512, bias=True) (dropout): Dropout(p=0.3, inplace=False) (bnfc1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (fc2): Linear(in_features=512, out_features=3, bias=True) ) Shape right before squeezing: torch.Size([32, 512, 37]) The Number of parameters of the model: 8,411,139
class BasicResBlock(nn.Module):
    """Residual block made of two same-kernel 1D convolutions.

    The first convolution applies ``stride``; the second always uses stride 1.
    When the stride is not 1 or the channel count changes, the shortcut goes
    through a 1x1 convolution plus batch norm so that it can be added to the
    main path.
    """

    # Ratio between the block's output channels and its ``c_out`` argument.
    expansion: int = 1

    def __init__(self, c_in, c_out, kernel_size, stride) -> None:
        super().__init__()
        half_kernel = kernel_size // 2

        self.conv1 = nn.Conv1d(c_in, c_out, kernel_size,
                               stride=stride, padding=half_kernel, bias=False)
        self.bn1 = nn.BatchNorm1d(c_out)
        self.conv2 = nn.Conv1d(c_out, c_out, kernel_size,
                               stride=1, padding=half_kernel, bias=False)
        self.bn2 = nn.BatchNorm1d(c_out)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut, only built when the identity cannot be added
        # as-is.
        needs_projection = stride != 1 or c_in != c_out
        self.downsample = nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm1d(c_out),
        ) if needs_projection else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
class BottleneckBlock(nn.Module):
    """Residual bottleneck block: 1x1 conv, k-wide conv, 1x1 expanding conv.

    The middle convolution applies ``stride``. The block outputs
    ``c_out * expansion`` channels. When the stride is not 1 or the input
    channel count differs from that, the shortcut goes through a 1x1
    convolution plus batch norm.
    """

    # Ratio between the block's output channels and its ``c_out`` argument.
    expansion: int = 4

    def __init__(self, c_in, c_out, kernel_size, stride) -> None:
        super().__init__()
        width = c_out
        c_expanded = c_out * self.expansion

        self.conv1 = nn.Conv1d(c_in, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm1d(width)
        self.conv2 = nn.Conv1d(width, width, kernel_size,
                               stride=stride, padding=kernel_size // 2, bias=False)
        self.bn2 = nn.BatchNorm1d(width)
        self.conv3 = nn.Conv1d(width, c_expanded, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm1d(c_expanded)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut, only built when the identity cannot be added
        # as-is.
        needs_projection = stride != 1 or c_in != c_expanded
        self.downsample = nn.Sequential(
            nn.Conv1d(c_in, c_expanded, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm1d(c_expanded),
        ) if needs_projection else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
class ResNet(nn.Module):
    """1D ResNet-style classifier with an optional age input.

    Layout: input stage (strided conv, batch norm, ReLU, max pool), four
    residual stages whose base widths are ``n_start * (1, 2, 4, 8)``, global
    pooling over time, then ``n_fc`` width-halving fully connected layers
    followed by a final linear classifier.

    Args:
        block: residual block class. It is called as
            ``block(c_in, c_out, kernel_size, stride=...)`` and must expose an
            ``expansion`` attribute (output channels = ``c_out * expansion``).
        conv_layers: number of residual blocks in each of the four stages.
        n_fc: number of hidden fully connected layers.
        n_input: number of input signal channels.
        n_output: number of output classes.
        n_start: base channel width of the input stage and first stage.
        kernel_size: kernel size of the residual blocks; the input
            convolution uses three times this size.
        use_age: when True, the ``age`` value is appended to the pooled
            features before the fully connected layers.
        final_pool: ``'average'`` or ``'max'``; global pooling applied over
            the time axis.

    Raises:
        ValueError: if ``final_pool`` is neither ``'average'`` nor ``'max'``.
    """

    def __init__(self,
                 block: Type[Union["BasicResBlock", "BottleneckBlock"]],
                 conv_layers: List[int],
                 n_fc: int,
                 n_input=20,
                 n_output=3,
                 n_start=64,
                 kernel_size=9,
                 use_age=True,
                 final_pool='average') -> None:
        super().__init__()

        if final_pool not in {'average', 'max'}:
            raise ValueError("final_pool must be set to one of ['average', 'max']")

        # Running channel/feature width; updated as layers are appended.
        self.c_current = n_start
        self.use_age = use_age

        self.input_stage = nn.Sequential(
            nn.Conv1d(in_channels=n_input, out_channels=n_start,
                      kernel_size=kernel_size*3, stride=2,
                      padding=(kernel_size*3)//2, bias=False),
            nn.BatchNorm1d(n_start),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3)
        )

        self.conv_stage1 = self._make_conv_layer(block, conv_layers[0], n_start, kernel_size, stride=5)
        self.conv_stage2 = self._make_conv_layer(block, conv_layers[1], n_start*2, kernel_size, stride=5)
        self.conv_stage3 = self._make_conv_layer(block, conv_layers[2], n_start*4, kernel_size, stride=5)
        self.conv_stage4 = self._make_conv_layer(block, conv_layers[3], n_start*8, kernel_size, stride=5)

        if final_pool == 'average':
            self.final_pool = nn.AdaptiveAvgPool1d(1)
        elif final_pool == 'max':
            self.final_pool = nn.AdaptiveMaxPool1d(1)

        fc_layers = []
        # One extra input feature when age is concatenated.
        if self.use_age:
            self.c_current = self.c_current + 1

        # Each hidden layer halves the feature width.
        for _ in range(n_fc):
            layer = nn.Sequential(nn.Linear(self.c_current, self.c_current // 2, bias=False),
                                  nn.Dropout(p=0.1),
                                  nn.BatchNorm1d(self.c_current // 2),
                                  nn.ReLU())
            self.c_current = self.c_current // 2
            fc_layers.append(layer)
        fc_layers.append(nn.Linear(self.c_current, n_output))
        self.fc_stage = nn.Sequential(*fc_layers)

    def reset_weights(self):
        """Re-initialise every submodule that provides ``reset_parameters``."""
        for m in self.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

    def _make_conv_layer(self, block: Type[Union["BasicResBlock", "BottleneckBlock"]],
                         n_block: int, c_out: int, kernel_size: int, stride: int = 1) -> nn.Sequential:
        """Build one residual stage and advance ``self.c_current``.

        The stage holds ``n_block`` blocks, all created with stride 1,
        followed by a max pool whose kernel size is ``stride``. Temporal
        downsampling is therefore done by the pooling layer, not by the
        blocks.
        """
        layers = []

        c_in = self.c_current
        layers.append(block(c_in, c_out, kernel_size, stride=1))

        c_in = c_out * block.expansion
        self.c_current = c_in
        for _ in range(1, n_block):
            layers.append(block(c_in, c_out, kernel_size, stride=1))

        layers.append(nn.MaxPool1d(kernel_size=stride))

        return nn.Sequential(*layers)

    def forward(self, x, age, print_shape=False):
        """Compute class scores for a batch.

        Args:
            x: signal batch of shape ``(batch, n_input, length)``.
            age: one age value per sample; reshaped to ``(batch, 1)`` when
                ``use_age`` is set, otherwise ignored.
            print_shape: when True, print the feature shape ahead of the
                final global pooling.

        Returns:
            Tensor of shape ``(batch, n_output)`` with raw class scores
            (no softmax applied).
        """
        x = self.input_stage(x)

        x = self.conv_stage1(x)
        x = self.conv_stage2(x)
        x = self.conv_stage3(x)
        x = self.conv_stage4(x)

        if print_shape:
            print('Shape right before squeezing:', x.shape)

        # Drop only the pooled time axis. A bare `squeeze()` would also remove
        # the batch axis when the batch holds a single sample.
        x = self.final_pool(x).flatten(start_dim=1)

        if self.use_age:
            x = torch.cat((x, age.reshape(-1, 1)), dim=1)
        x = self.fc_stage(x)

        return x
def generate_ResNet():
    """Create the bottleneck 1D ResNet variant that uses age as an extra feature.

    The number of input channels is read from the first sample of the
    module-level ``train_dataset``.
    """
    n_channels = train_dataset[0]['signal'].shape[0]
    config = dict(block=BottleneckBlock,
                  conv_layers=[2, 2, 2, 2],
                  n_fc=3,
                  n_input=n_channels,
                  n_output=3,
                  n_start=64,
                  kernel_size=9,
                  use_age=True,
                  final_pool='max')
    return ResNet(**config)
# Build the age-aware ResNet and move it to the selected device as float32.
model = generate_ResNet()
model = model.to(device, dtype=torch.float32)
# Print the module hierarchy, followed by a blank line.
print(model)
print()
# Pass the model to the TensorBoard helper under the tag '1D-ResNet'
# (helper is defined elsewhere in the notebook; presumably it traces the graph - verify).
visualize_network_tensorboard(model, '1D-ResNet')
# Report the value returned by count_parameters with thousands separators.
n = count_parameters(model)
print(f'The Number of parameters of the model: {n:,}')
# This instance was built for inspection only; drop the reference.
del model
ResNet(
(input_stage): Sequential(
(0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage1): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(256, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage2): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(256, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(256, 512, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage3): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(512, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(256, 256, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(256, 1024, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(512, 1024, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(1024, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(256, 256, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(256, 1024, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage4): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(1024, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(512, 2048, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(2048, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(512, 2048, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(final_pool): AdaptiveMaxPool1d(output_size=1)
(fc_stage): Sequential(
(0): Sequential(
(0): Linear(in_features=2049, out_features=1024, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(1): Sequential(
(0): Linear(in_features=1024, out_features=512, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(2): Sequential(
(0): Linear(in_features=512, out_features=256, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(3): Linear(in_features=256, out_features=3, bias=True)
)
)
Shape right before squeezing: torch.Size([32, 2048, 3])
The Number of parameters of the model: 16,729,219
def generate_ResNet_no_age():
    """Create the bottleneck 1D ResNet variant that ignores the age input.

    Identical settings to the age-aware variant except ``use_age=False``.
    The number of input channels is read from the first sample of the
    module-level ``train_dataset``.
    """
    n_channels = train_dataset[0]['signal'].shape[0]
    config = dict(block=BottleneckBlock,
                  conv_layers=[2, 2, 2, 2],
                  n_fc=3,
                  n_input=n_channels,
                  n_output=3,
                  n_start=64,
                  kernel_size=9,
                  use_age=False,
                  final_pool='max')
    return ResNet(**config)
# Build the ResNet variant without the age feature and move it to the device as float32.
model = generate_ResNet_no_age()
model = model.to(device, dtype=torch.float32)
# Print the module hierarchy, followed by a blank line.
print(model)
print()
# Pass the model to the TensorBoard helper under the tag '1D-ResNet-no-age'
# (helper is defined elsewhere in the notebook; presumably it traces the graph - verify).
visualize_network_tensorboard(model, '1D-ResNet-no-age')
# Report the value returned by count_parameters with thousands separators.
n = count_parameters(model)
print(f'The Number of parameters of the model: {n:,}')
# This instance was built for inspection only; drop the reference.
del model
ResNet(
(input_stage): Sequential(
(0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage1): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(256, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage2): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(256, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(256, 512, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage3): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(512, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(256, 256, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(256, 1024, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(512, 1024, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(1024, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(256, 256, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(256, 1024, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage4): Sequential(
(0): BottleneckBlock(
(conv1): Conv1d(1024, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(512, 2048, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BottleneckBlock(
(conv1): Conv1d(2048, 512, kernel_size=(1,), stride=(1,), bias=False)
(bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv1d(512, 2048, kernel_size=(1,), stride=(1,), bias=False)
(bn3): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(final_pool): AdaptiveMaxPool1d(output_size=1)
(fc_stage): Sequential(
(0): Sequential(
(0): Linear(in_features=2048, out_features=1024, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(1): Sequential(
(0): Linear(in_features=1024, out_features=512, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(2): Sequential(
(0): Linear(in_features=512, out_features=256, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(3): Linear(in_features=256, out_features=3, bias=True)
)
)
Shape right before squeezing: torch.Size([32, 2048, 3])
The Number of parameters of the model: 16,728,195
def generate_TinyResNet():
    """Create a small 1D ResNet: one BasicResBlock per stage, no age input.

    ``final_pool`` is not passed, so the ResNet constructor's default applies.
    The number of input channels is read from the first sample of the
    module-level ``train_dataset``.
    """
    n_channels = train_dataset[0]['signal'].shape[0]
    config = dict(block=BasicResBlock,
                  conv_layers=[1, 1, 1, 1],
                  n_fc=3,
                  n_input=n_channels,
                  n_output=3,
                  n_start=64,
                  kernel_size=9,
                  use_age=False)
    return ResNet(**config)
# Build the tiny ResNet and move it to the selected device as float32.
model = generate_TinyResNet()
model = model.to(device, dtype=torch.float32)
# Print the module hierarchy, followed by a blank line.
print(model)
print()
# Report the value returned by count_parameters with thousands separators.
n = count_parameters(model)
print(f'The Number of parameters of the model: {n:,}')
# This instance was built for inspection only; drop the reference.
del model
ResNet(
(input_stage): Sequential(
(0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage1): Sequential(
(0): BasicResBlock(
(conv1): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage2): Sequential(
(0): BasicResBlock(
(conv1): Conv1d(64, 128, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(128, 128, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage3): Sequential(
(0): BasicResBlock(
(conv1): Conv1d(128, 256, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(256, 256, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(conv_stage4): Sequential(
(0): BasicResBlock(
(conv1): Conv1d(256, 512, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv1d(512, 512, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
(bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv1d(256, 512, kernel_size=(1,), stride=(1,), bias=False)
(1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
)
(final_pool): AdaptiveAvgPool1d(output_size=1)
(fc_stage): Sequential(
(0): Sequential(
(0): Linear(in_features=512, out_features=256, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(1): Sequential(
(0): Linear(in_features=256, out_features=128, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(2): Sequential(
(0): Linear(in_features=128, out_features=64, bias=False)
(1): Dropout(p=0.1, inplace=False)
(2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU()
)
(3): Linear(in_features=64, out_features=3, bias=True)
)
)
The Number of parameters of the model: 5,104,067
def check_val_accuracy(model, repeat=1):
    """Evaluate ``model`` on the module-level ``val_loader``.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: network called as ``model(signal, age)``.
        repeat: number of full passes over ``val_loader`` to accumulate.

    Returns:
        ``(val_accuracy, val_confusion)``: accuracy in percent and a C x C
        int32 confusion matrix (rows: ground truth, columns: prediction),
        where C is ``len(class_label_to_type)``.

    Raises:
        ZeroDivisionError: if no sample was evaluated (empty loader or
            ``repeat < 1``).
    """
    model.eval()
    correct, total = (0, 0)
    C = len(class_label_to_type)
    val_confusion = np.zeros((C, C), dtype=np.int32)

    # Pure inference: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for _ in range(repeat):
            for sample_batched in val_loader:
                # pull up the data
                x = sample_batched['signal'].to(device)
                age = sample_batched['age'].to(device)
                target = sample_batched['class_label'].to(device)

                # apply model on whole batch directly on device
                output = model(x, age)
                pred = F.log_softmax(output, dim=1)

                # val accuracy
                pred = pred.argmax(dim=-1)
                correct += pred.squeeze().eq(target).sum().item()
                total += pred.shape[0]

                # confusion matrix
                val_confusion += calculate_confusion_matrix(pred, target)

    val_accuracy = 100.0 * correct / total
    return (val_accuracy, val_confusion)
def check_test_accuracy(model, repeat=1):
    """Evaluate ``model`` on the module-level ``test_loader``.

    The model is switched to eval mode and is NOT switched back.

    Args:
        model: network called as ``model(signal, age)``.
        repeat: number of full passes over ``test_loader`` to accumulate.

    Returns:
        ``(test_accuracy, test_confusion, test_debug, score, target)`` where
        ``test_accuracy`` is in percent, ``test_confusion`` is a C x C int32
        matrix (rows: ground truth, columns: prediction), ``test_debug`` maps
        each sample serial to its ground truth ('GT'), per-class prediction
        counts ('Pred'), formatted accuracy ('Acc') and 'edfname', and
        ``score`` / ``target`` are the softmax scores and labels of every
        evaluated sample concatenated along axis 0.

    Raises:
        ZeroDivisionError: if no sample was evaluated (empty loader or
            ``repeat < 1``).
    """
    model.eval()
    correct, total = (0, 0)
    C = len(class_label_to_type)
    test_confusion = np.zeros((C, C), dtype=np.int32)

    test_debug = {data['metadata']['serial']:
                  {'GT': data['class_label'].item(),
                   'Acc': 0,
                   'Pred': [0] * C} for data in test_dataset}

    # Collect per-batch arrays and concatenate once at the end; concatenating
    # inside the loop re-copies everything gathered so far on every batch.
    score_chunks = []
    target_chunks = []

    # Pure inference: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for _ in range(repeat):
            for sample_batched in test_loader:
                # pull up the data
                x = sample_batched['signal'].to(device)
                age = sample_batched['age'].to(device)
                y = sample_batched['class_label'].to(device)

                # apply model on whole batch directly on device
                output = model(x, age)
                s = F.softmax(output, dim=1)
                pred = F.log_softmax(output, dim=1)

                # test accuracy
                pred = pred.argmax(dim=-1)
                correct += pred.squeeze().eq(y).sum().item()
                total += pred.shape[0]

                score_chunks.append(s.detach().cpu().numpy())
                target_chunks.append(y.detach().cpu().numpy())

                # confusion matrix
                test_confusion += calculate_confusion_matrix(pred, y)

                # test debug: running per-serial prediction histogram and accuracy
                for n in range(pred.shape[0]):
                    serial = sample_batched['metadata'][n]['serial']
                    test_debug[serial]['edfname'] = sample_batched['metadata'][n]['edfname']
                    test_debug[serial]['Pred'][pred[n].item()] += 1
                    acc = test_debug[serial]['Pred'][y[n].item()] / np.sum(test_debug[serial]['Pred']) * 100
                    test_debug[serial]['Acc'] = f'{acc:>6.02f}%'

    # Raises ZeroDivisionError before concatenation when nothing was evaluated.
    test_accuracy = 100.0 * correct / total
    score = np.concatenate(score_chunks, axis=0)
    target = np.concatenate(target_chunks, axis=0)
    return (test_accuracy, test_confusion, test_debug, score, target)
def calculate_confusion_matrix(pred, target):
    """Count (ground truth, prediction) pairs into a C x C int32 matrix.

    Rows are indexed by ``target[i]`` and columns by ``pred[i]``; C is
    ``len(class_label_to_type)``. One count is added per entry of ``target``.
    """
    n_samples = target.shape[0]
    n_classes = len(class_label_to_type)
    counts = np.zeros((n_classes, n_classes), dtype=np.int32)
    for idx in range(n_samples):
        counts[target[idx], pred[idx]] += 1
    return counts
def draw_loss_plot(losses, lr_decay_step):
    """Plot the training loss per iteration with learning-rate decay markers.

    Args:
        losses: sequence of loss values, plotted against iterations 1..len(losses).
        lr_decay_step: spacing (in iterations) of the vertical marker lines.
    """
    plt.style.use('default')  # default, ggplot, fivethirtyeight, classic

    fig = plt.figure(num=1, clear=True, figsize=(8.0, 3.0), constrained_layout=True)
    ax = fig.add_subplot(1, 1, 1)

    n_iters = len(losses)
    iterations = np.arange(1, n_iters + 1)
    decay_marks = np.arange(lr_decay_step, n_iters, lr_decay_step)

    ax.plot(iterations, losses)
    # Full-height vertical lines: x in data coordinates, y in axes coordinates.
    ax.vlines(decay_marks, 0, 1, transform=ax.get_xaxis_transform(),
              colors='m', alpha=0.5, linestyle='solid')

    ax.set_xlim(left=0)
    ax.set_title('Loss Plot')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Training Loss')

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    fig.clear()
    plt.close(fig)
def draw_accuracy_history(train_acc_history, val_acc_history, history_interval, lr_decay_step):
    """Plot train and validation accuracy histories with learning-rate decay markers.

    Args:
        train_acc_history: train accuracies, one per recorded checkpoint.
        val_acc_history: validation accuracies, one per recorded checkpoint.
        history_interval: number of iterations between recorded checkpoints.
        lr_decay_step: spacing (in iterations) of the vertical marker lines.
    """
    plt.style.use('default')  # default, ggplot, fivethirtyeight, classic

    fig = plt.figure(num=1, clear=True, figsize=(8.0, 3.0), constrained_layout=True)
    ax = fig.add_subplot(1, 1, 1)

    last_iter = len(train_acc_history) * history_interval
    checkpoints = np.arange(history_interval, last_iter + 1, history_interval)
    decay_marks = np.arange(lr_decay_step, last_iter + 1, lr_decay_step)

    curves = ((train_acc_history, 'r-', 'Train accuracy'),
              (val_acc_history, 'b-', 'Validation accuracy'))
    for history, line_format, legend_label in curves:
        ax.plot(checkpoints, history, line_format, label=legend_label)

    # Full-height vertical lines: x in data coordinates, y in axes coordinates.
    ax.vlines(decay_marks, 0, 1, transform=ax.get_xaxis_transform(),
              colors='m', alpha=0.5, linestyle='solid')

    ax.set_xlim(left=0)
    ax.legend(loc='lower right')
    ax.set_title('Accuracy Plot during Training')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Accuracy (%)')

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    fig.clear()
    plt.close(fig)
def draw_confusion(confusion):
    """Show a confusion matrix as an annotated heat map.

    Args:
        confusion: matrix indexed as ``confusion[row, col]`` with
            ``len(class_label_to_type)`` rows and columns; rows are drawn on
            the 'Ground Truth' axis and columns on the 'Prediction' axis.

    Note:
        Sets ``plt.rcParams['image.cmap']`` to 'jet' as a global side effect.
    """
    n_classes = len(class_label_to_type)
    tick_positions = np.arange(n_classes)

    plt.style.use('default')  # default, ggplot, fivethirtyeight, classic
    plt.rcParams['image.cmap'] = 'jet'  # 'nipy_spectral'

    fig = plt.figure(num=1, clear=True, figsize=(4.0, 4.0), constrained_layout=True)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(confusion, alpha=0.8)

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(class_label_to_type)
    ax.set_yticklabels(class_label_to_type)

    # Write each count into its cell (x is the column, y is the row).
    for row, col in np.ndindex(n_classes, n_classes):
        ax.text(col, row, confusion[row, col],
                ha="center", va="center", color='k')

    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Prediction')
    ax.set_ylabel('Ground Truth')
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    fig.clear()
    plt.close(fig)
def draw_roc_curve(score, target):
    """Draw per-class and micro/macro-averaged one-vs-rest ROC curves.

    Args:
        score: 2-D array indexed as ``score[:, i]`` holding the score of
            class ``i`` for every sample.
        target: integer class labels, one per sample; binarized internally
            against ``range(len(class_label_to_type))``.
    """
    plt.style.use('default')  # default, ggplot, fivethirtyeight, classic

    # Binarize the output
    n_classes = len(class_label_to_type)
    target = label_binarize(target, classes=np.arange(n_classes))
    if n_classes == 2:
        # label_binarize returns a single column for binary problems; expand
        # it to one column per class so it lines up with score[:, i].
        target = np.hstack((1 - target, target))

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(target[:, i], score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(target.ravel(), score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: aggregate all false positive rates ...
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # ... interpolate every per-class ROC curve at these points ...
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # ... and finally average them and compute the AUC
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Left panel: class-wise ROC curves
    fig = plt.figure(num=1, clear=True, figsize=(8.5, 4.0), constrained_layout=True)
    ax = fig.add_subplot(1, 2, 1)
    lw = 1.5
    colors = cycle(['limegreen', 'mediumpurple', 'darkorange',
                    'dodgerblue', 'lightcoral', 'goldenrod',
                    'indigo', 'darkgreen', 'navy', 'brown'])
    for i, color in zip(range(n_classes), colors):
        ax.plot(fpr[i], tpr[i], color=color, lw=lw,
                label='{0} (area = {1:0.2f})'
                      ''.format(class_label_to_type[i], roc_auc[i]))
    ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Class-Wise ROC Curves')
    ax.legend(loc="lower right")

    # Right panel: class-agnostic (micro / macro averaged) ROC curves.
    # Draw on this subplot's own axes instead of relying on pyplot's
    # implicit "current axes".
    ax = fig.add_subplot(1, 2, 2)
    ax.plot(fpr["micro"], tpr["micro"],
            label='micro-average (area = {0:0.2f})'
                  ''.format(roc_auc["micro"]),
            color='deeppink', linestyle='-', linewidth=lw)

    ax.plot(fpr["macro"], tpr["macro"],
            label='macro-average (area = {0:0.2f})'
                  ''.format(roc_auc["macro"]),
            color='navy', linestyle='-', linewidth=lw)

    ax.plot([0, 1], [0, 1], 'k--', lw=lw)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Class-Agnostic ROC Curves')
    ax.legend(loc="lower right")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    fig.clear()
    plt.close(fig)
def learning_rate_search(model, min_log_lr, max_log_lr, trials, iters):
    """Random search over the learning rate on a log10 scale.

    For each trial a learning rate ``10 ** u`` with ``u`` drawn uniformly from
    ``[min_log_lr, max_log_lr)`` is sampled, the model weights are reset, and
    the model is trained with AdamW on the module-level ``train_loader``.

    Args:
        model: network called as ``model(signal, age)``; must provide
            ``reset_weights()``.
        min_log_lr: lower bound of the log10 learning rate.
        max_log_lr: upper bound of the log10 learning rate.
        trials: number of learning rates to try.
        iters: number of optimizer updates per trial (at least one update is
            always performed).

    Returns:
        List of ``(log_lr, train_accuracy)`` tuples, one per trial, where
        ``train_accuracy`` is the accuracy in percent accumulated over all
        batches seen during that trial.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    learning_rate_record = []
    # Always run at least one update so the accuracy below is well defined.
    n_steps = max(1, iters)

    for _ in tqdm(range(trials)):
        log_lr = np.random.uniform(min_log_lr, max_log_lr)
        lr = 10 ** log_lr

        model.reset_weights()
        model.train()

        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0001)
        correct, total = (0, 0)

        step = 0
        while step < n_steps:
            saw_batch = False
            for sample_batched in train_loader:
                saw_batch = True
                x = sample_batched['signal'].to(device)
                age = sample_batched['age'].to(device)
                target = sample_batched['class_label'].to(device)

                output = model(x, age)
                pred = F.log_softmax(output, dim=1)
                loss = F.nll_loss(pred, target)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                pred = pred.argmax(dim=-1)
                correct += pred.squeeze().eq(target).sum().item()
                total += pred.shape[0]

                # Count completed updates; stop after exactly n_steps of them.
                step += 1
                if step >= n_steps:
                    break
            if not saw_batch:
                # Without this guard an empty loader would spin forever.
                raise ValueError('train_loader yielded no batches')

        train_accuracy = 100.0 * correct / total

        # Train accuracy accumulated over the whole trial is stored
        learning_rate_record.append((log_lr, train_accuracy))

    return learning_rate_record
def draw_learning_rate_record(learning_rate_record):
    """Scatter-plot the outcome of a learning rate search.

    Args:
        learning_rate_record: Iterable of ``(log_lr, train_accuracy)`` pairs;
            the x axis shows the log-scale learning rate and the y axis the
            train accuracy.
    """
    plt.style.use('default')  # default, ggplot, fivethirtyeight, classic

    fig = plt.figure(num=1, clear=True, constrained_layout=True)  # figsize=(6.0, 6.0)
    ax = fig.add_subplot(1, 1, 1)

    ax.set_title('Learning Rate Search')
    ax.set_xlabel('Learning rate in log-scale')
    ax.set_ylabel('Train accuracy')

    # Materialise once so one-shot iterables are also supported, then draw
    # every point with a single scatter call instead of one artist per point.
    points = list(learning_rate_record)
    log_lrs = [log_lr for log_lr, _ in points]
    train_accuracies = [train_accuracy for _, train_accuracy in points]
    ax.scatter(log_lrs, train_accuracies, c='r',
               alpha=0.5, edgecolors='none')

    plt.show()
    fig.clear()
    plt.close(fig)
# Pool of candidate models. Each entry stores a display name, a zero-argument
# factory that builds the network, and the starting learning rate
# ('lr_start' is None until a value has been chosen for the model).
model_pool = [
    {'name': name, 'generator': generator, 'lr_start': None}
    for name, generator in (
        ('TinyCNN', generate_TinyCNN),
        ('M5', generate_M5),
        ('M5-no-age', generate_M5_no_age),
        ('1D-ResNet', generate_ResNet),
        ('1D-ResNet-no-age', generate_ResNet_no_age),
        # NOTE(review): this entry reuses generate_TinyCNN although its name
        # suggests a different architecture - possibly a copy/paste slip;
        # confirm which generator was intended.
        ('1D-TinyResNet-no-age', generate_TinyCNN),
    )
]

pprint.pp(model_pool, width=150)
[{'name': 'TinyCNN', 'generator': <function generate_TinyCNN at 0x000001B10CD8F1F0>, 'lr_start': None},
{'name': 'M5', 'generator': <function generate_M5 at 0x000001B10FFF28B0>, 'lr_start': None},
{'name': 'M5-no-age', 'generator': <function generate_M5_no_age at 0x000001B10FFF2EE0>, 'lr_start': None},
{'name': '1D-ResNet', 'generator': <function generate_ResNet at 0x000001B10FF3F280>, 'lr_start': None},
{'name': '1D-ResNet-no-age', 'generator': <function generate_ResNet_no_age at 0x000001B11005B310>, 'lr_start': None},
{'name': '1D-TinyResNet-no-age', 'generator': <function generate_TinyCNN at 0x000001B10CD8F1F0>, 'lr_start': None}]
# Fill in 'lr_start' for every model that does not have one yet by running a
# random learning rate search; models with a preset value are only reported.
for model_dict in model_pool:
    if model_dict['lr_start'] is not None:
        print(f'{model_dict["name"]}: {model_dict["lr_start"]:.5e}')
    else:
        print(f'{model_dict["name"]} LR searching..')

        model = model_dict['generator']().to(device)
        model.train()

        record = learning_rate_search(model, min_log_lr=-4.5, max_log_lr=-1.4,
                                      trials=300, iters=30)
        draw_learning_rate_record(record)

        # Keep the sampled log-lr whose trial reached the highest accuracy
        # (the first one wins on ties).
        best_log_lr = max(record, key=lambda pair: pair[1])[0]
        model_dict['lr_start'] = 10 ** best_log_lr
        print(f'best lr {model_dict["lr_start"]:.5e} / log_lr {best_log_lr}')

    print('-' * 100)

pprint.pp(model_pool, width=150)
TinyCNN LR searching..
best lr 8.79702e-03 / log_lr -2.0556643409724886 ---------------------------------------------------------------------------------------------------- M5 LR searching..
best lr 1.40040e-04 / log_lr -3.8537494114231405 ---------------------------------------------------------------------------------------------------- M5-no-age LR searching..
best lr 2.96744e-03 / log_lr -2.527618442281792 ---------------------------------------------------------------------------------------------------- 1D-ResNet LR searching..
best lr 6.87017e-03 / log_lr -2.1630322083007694 ---------------------------------------------------------------------------------------------------- 1D-ResNet-no-age LR searching..
best lr 9.95003e-03 / log_lr -2.002175463093898 ---------------------------------------------------------------------------------------------------- 1D-TinyResNet-no-age LR searching..
best lr 1.98976e-02 / log_lr -1.7011992936537332
----------------------------------------------------------------------------------------------------
[{'name': 'TinyCNN', 'generator': <function generate_TinyCNN at 0x000001B10CD8F1F0>, 'lr_start': 0.008797021614158067},
{'name': 'M5', 'generator': <function generate_M5 at 0x000001B10FFF28B0>, 'lr_start': 0.00014003951196149358},
{'name': 'M5-no-age', 'generator': <function generate_M5_no_age at 0x000001B10FFF2EE0>, 'lr_start': 0.0029674373433992507},
{'name': '1D-ResNet', 'generator': <function generate_ResNet at 0x000001B10FF3F280>, 'lr_start': 0.006870174872913063},
{'name': '1D-ResNet-no-age', 'generator': <function generate_ResNet_no_age at 0x000001B11005B310>, 'lr_start': 0.009950033361741028},
{'name': '1D-TinyResNet-no-age', 'generator': <function generate_TinyCNN at 0x000001B10CD8F1F0>, 'lr_start': 0.019897600472783057}]
# Switches for the training run below:
#   file_check  - look for an already saved log before training a model
#   save_model  - write the selected model weights to disk
#   draw_result - draw the loss / accuracy / ROC / confusion figures
file_check, save_model, draw_result = True, True, True

# Directory that receives the logs and weights of this notebook.
log_path = f'history_temp/{nb_fname}/'
os.makedirs(log_path, exist_ok=True)

# Iteration budget. `crop_length` is defined earlier in the notebook.
n_repeats = 1
n_iters = ((200 * 60) // crop_length) * 12500
history_interval = n_iters // 300       # iterations between accuracy checkpoints
lr_decay_step = round(n_iters / 2.5)    # step size later given to the StepLR scheduler
print(f'-- The number of iterations for each model: {n_iters}--')

# One shared progress bar covering every model and every repeat.
pbar = tqdm(total=n_repeats * n_iters * len(model_pool))
# Train every model in model_pool (n_repeats times each), log the training
# history, evaluate on the test set and optionally save / visualise the result.
# NOTE(review): the indentation of this cell was lost in this export, so the
# comments below describe individual statements and avoid asserting nesting.
for model_dict in model_pool:
print(f'{"*"*40} {model_dict["name"]} train starts {"*"*40}')
# Best test accuracy seen so far among the repeats of the current model.
best_r_test_acc = 0
for r in range(n_repeats):
# When a log file for this model/repeat already exists, replay its stored
# curves and summary instead of training again.
if file_check:
endwith = '' if n_repeats == 1 else f'_r{r:02d}'
path = os.path.join(log_path, f'{model_dict["name"]}_log{endwith}')
if os.path.isfile(path):
log_dict = torch.load(path)
# loss and accuracy plots
if draw_result:
draw_loss_plot(log_dict["losses"], log_dict["lr_decay_step"])
draw_accuracy_history(log_dict["train_acc_history"], log_dict["val_acc_history"],
log_dict["history_interval"], log_dict["lr_decay_step"])
script = f'- {r:02d} train accuracy {log_dict["train_acc_history"][-1]:.2f}%, '\
f'best / last test accuracies {log_dict["best_test_accuracy"]:.2f}% / {log_dict["last_test_accuracy"]:.2f}% - file exists'
print()
print(script)
print()
# The skipped run still counts as n_iters steps on the progress bar.
pbar.update(n_iters)
continue
# build a fresh model from its generator
model = model_dict['generator']().to(device)
model.train()
lr_start = model_dict['lr_start']
# configure for training
optimizer = optim.AdamW(model.parameters(), lr=lr_start, weight_decay=0.0001)
# The learning rate is multiplied by 0.1 every lr_decay_step scheduler steps;
# scheduler.step() is called once per batch below.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
# log during training
losses = []
train_acc_history = []
val_acc_history = []
best_val_acc = 0
correct, total = (0, 0)
# The counter starts at 1 and the loops stop once it reaches n_iters.
# NOTE(review): this yields n_iters - 1 updates (for n_iters >= 2), so the
# progress bar advances one step less here than in the "file exists" branch.
i = 1
while True:
for sample_batched in train_loader:
# Re-enter train mode for every batch - presumably because the validation
# helper called below may switch the model to eval mode; confirm.
model.train()
# load the data
x = sample_batched['signal'].to(device)
age = sample_batched['age'].to(device)
target = sample_batched['class_label'].to(device)
# forward pass
output = model(x, age)
pred = F.log_softmax(output, dim=1)
loss = F.nll_loss(pred, target)
# backward and update
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
# train accuracy (accumulated until the next history checkpoint)
pred = pred.argmax(dim=-1)
correct += pred.squeeze().eq(target).sum().item()
total += pred.shape[0]
# log
losses.append(loss.item())
pbar.update(1)
i += 1
# history: every history_interval iterations store the train accuracy since
# the previous checkpoint, reset the counters and record the validation accuracy
# NOTE(review): history_interval is n_iters // 300, which is 0 when
# n_iters < 300 and would make this modulo raise ZeroDivisionError.
if i % history_interval == 0:
train_acc = 100.0 * correct / total
train_acc_history.append(train_acc)
correct, total = (0, 0)
val_acc, _ = check_val_accuracy(model, repeat=5)
val_acc_history.append(val_acc)
# Snapshot the weights whenever the validation accuracy improves.
# NOTE(review): best_model_state is assigned only here and is never reset
# together with best_val_acc, so if no checkpoint exceeds 0 it is undefined
# or left over from a previous model/repeat when it is loaded further below.
if best_val_acc < val_acc:
best_val_acc = val_acc
best_model_state = deepcopy(model.state_dict())
# leave the batch loop, then the surrounding while loop
if i >= n_iters:
break
if i >= n_iters:
break
# loss and accuracy plots
if draw_result:
draw_loss_plot(losses, lr_decay_step)
draw_accuracy_history(train_acc_history, val_acc_history, history_interval, lr_decay_step)
# calculate the test accuracies for best and last models
last_model_state = deepcopy(model.state_dict())
last_test_acc, last_test_confusion, last_test_debug, _, _ = check_test_accuracy(model, repeat=30)
model.load_state_dict(best_model_state)
best_test_acc, best_test_confusion, best_test_debug, _, _ = check_test_accuracy(model, repeat=30)
# save the model if it is best among repeatedly trained models
# NOTE(review): the saved state (last vs. best-validation) is picked by
# comparing test accuracies - confirm this selection on test data is intended.
if save_model and best_r_test_acc < max(last_test_acc, best_test_acc):
best_r_test_acc = max(last_test_acc, best_test_acc)
model_state = last_model_state if best_test_acc < last_test_acc else best_model_state
path = os.path.join(log_path, f'{model_dict["name"]}')
torch.save(model_state, path)
# leave the log (same file name that the file_check branch looks for)
endwith = '' if n_repeats == 1 else f'_r{r:02d}'
path = os.path.join(log_path, f'{model_dict["name"]}_log{endwith}')
log_dict = {}
log_dict['model'] = model_dict['name']
log_dict['starting_lr'] = lr_start
log_dict['final_lr'] = optimizer.param_groups[-1]["lr"]
log_dict['history_interval'] = history_interval
log_dict['lr_decay_step'] = lr_decay_step
log_dict['losses'] = losses
log_dict['train_acc_history'] = train_acc_history
log_dict['val_acc_history'] = val_acc_history
log_dict['best_test_accuracy'] = best_test_acc
log_dict['best_test_confusion'] = best_test_confusion
log_dict['best_test_debug'] = best_test_debug
log_dict['last_test_accuracy'] = last_test_acc
log_dict['last_test_confusion'] = last_test_confusion
log_dict['last_test_debug'] = last_test_debug
torch.save(log_dict, path)
# train_acc here is the value computed at the most recent history checkpoint.
script = f'- {r:02d} train accuracy {train_acc:.2f}%, '\
f'best / last test accuracies {best_test_acc:.2f}% / {last_test_acc:.2f}%'
print()
print(script)
print()
# Reload the saved weights into a freshly generated model and draw the ROC
# curves and the confusion matrix from a new test evaluation.
if draw_result and save_model:
model = model_dict['generator']().to(device)
path = os.path.join(log_path, f'{model_dict["name"]}')
model.load_state_dict(torch.load(path))
temp_result = check_test_accuracy(model, repeat=30)
test_acc, test_confusion, test_debug, score, target = temp_result
draw_roc_curve(score, target)
draw_confusion(test_confusion)
print('\n' * 2)
print()
-- The number of iterations for each model: 12500--
**************************************** TinyCNN train starts ****************************************
- 00 train accuracy 84.68%, best / last test accuracies 59.10% / 58.88%
**************************************** M5 train starts ****************************************
- 00 train accuracy 99.77%, best / last test accuracies 58.24% / 57.50%
**************************************** M5-no-age train starts ****************************************
- 00 train accuracy 97.03%, best / last test accuracies 57.18% / 55.99%
**************************************** 1D-ResNet train starts ****************************************
- 00 train accuracy 93.60%, best / last test accuracies 59.78% / 59.84%
**************************************** 1D-ResNet-no-age train starts ****************************************
- 00 train accuracy 91.16%, best / last test accuracies 57.12% / 56.28%
**************************************** 1D-TinyResNet-no-age train starts ****************************************
- 00 train accuracy 80.72%, best / last test accuracies 59.29% / 58.37%